import os
import sys

import torch.nn as nn
from PIL import Image
from torchvision import transforms
from matplotlib import pyplot as plt

# Directory containing this script. abspath() is applied first so the result is
# absolute even when __file__ is given as a relative path.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The shared helpers are expected at <script dir>/tools/common_tools.py.
path_tools = os.path.abspath(os.path.join(BASE_DIR, 'tools', 'common_tools.py'))
# Explicit raise instead of `assert`: assertions are stripped under `python -O`,
# which would silently skip this check.
if not os.path.exists(path_tools):
    raise FileNotFoundError("{}不存在，请将common_tools.py文件放到{}".format(path_tools, os.path.dirname(path_tools)))

# Put the script's parent directory on sys.path so the `deepeye` package below
# is importable. Built from BASE_DIR (already absolute): concatenating
# os.path.dirname(__file__) + os.sep + ".." yields "/.." -> "/" whenever
# __file__ is a bare relative filename, because dirname() then returns "".
hello_pytorch_DIR = os.path.abspath(os.path.join(BASE_DIR, os.pardir))
sys.path.append(hello_pytorch_DIR)

from deepeye.tools.common_tools import transform_invert, set_seed  # noqa: E402  (needs sys.path above)

# Fix the random seed so the randomly initialised convolution kernels are
# reproducible. Different seeds (e.g. 0, 1, 2, 3) give different kernel weights;
# each set of weights acts as a different pattern / feature detector, so the
# resulting feature map changes with the seed.
set_seed(3)

# ================================= load image ==================================
# Read the sample image from <script dir>/data/images and force 3-channel RGB.
path_img = os.path.join(BASE_DIR, 'data', 'images', 'lena.png')
with Image.open(path_img) as img_file:
    img = img_file.convert('RGB')  # pixel values 0~255

# ToTensor(): PIL image (H*W*C, 0~255) -> float tensor (C*H*W, 0~1)
img_transform = transforms.Compose([transforms.ToTensor()])
# Prepend a batch axis of size 1: C*H*W -> B*C*H*W, as the conv layers expect.
img_tensor = img_transform(img).unsqueeze(0)

# ================================= create convolution layer ==================================
# Each section below runs when its `flag` is 1 and is skipped when it is 0.
# If both are enabled, the later (transposed) section's output overwrites the first.
img_conv = None  # bound by whichever section runs; checked at the end

# ============ 2d
# flag = 0
flag = 1
if flag:
    # nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3)
    # weight layout: (out_channels, in_channels, kH, kW) -> (1, 3, 3, 3)
    conv_layer = nn.Conv2d(3, 1, 3)
    nn.init.xavier_normal_(conv_layer.weight)

    # calculation
    img_conv = conv_layer(img_tensor)

# ============ transposed
flag = 0
# flag = 1
if flag:
    # nn.ConvTranspose2d(in_channels=3, out_channels=1, kernel_size=3, stride=2)
    # weight layout is swapped relative to Conv2d: (in_channels, out_channels, kH, kW) -> (3, 1, 3, 3)
    conv_layer = nn.ConvTranspose2d(3, 1, 3, stride=2)
    nn.init.xavier_normal_(conv_layer.weight)

    # calculation
    img_conv = conv_layer(img_tensor)

# Fail here with a clear message instead of a NameError in the visualization step.
if img_conv is None:
    raise RuntimeError("No convolution selected: set flag = 1 for the 2d or the transposed section")


# ================================= visualization ==================================
# Report the tensor shapes before / after the convolution.
print("卷积前尺寸:{}\n卷积后尺寸：{}".format(img_tensor.shape, img_conv.shape))
# Keep only the first output channel of the first (only) sample. detach() drops
# the autograd graph produced by the layer's trainable weights. Converting a
# tensor that requires grad to NumPy raises an error, and transform_invert is
# defined elsewhere, so whether it detaches internally is not visible here.
img_conv = transform_invert(img_conv[0, 0:1, ...].detach(), img_transform)
# Index the batch axis explicitly; squeeze() would also drop H or W if either were 1.
img_raw = transform_invert(img_tensor[0], img_transform)
plt.subplot(121).imshow(img_raw)                 # left: original image
plt.subplot(122).imshow(img_conv, cmap='gray')   # right: single-channel output
plt.show()



